//
// Copyright 2023 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package keyvalue

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"strings"

	"github.com/vektah/gqlparser/v2/gqlerror"

	"github.com/guacsec/guac/internal/testing/ptrfrom"
	"github.com/guacsec/guac/pkg/assembler/graphql/model"
	"github.com/guacsec/guac/pkg/assembler/kv"
)

// noVulnType is the vulnerability type name that the query paths in this file
// substitute for the filter's Type when the spec's NoVuln flag is true, and
// reject as an explicit Type when NoVuln is false.
const noVulnType string = "novuln"

// Internal data: Vulnerability
//
// vulnTypeStruct is the stored node for one vulnerability type.
//   - ThisID is the node's own ID, as returned by ID().
//   - Type is the type name; it is the only input to Key(). The ingest path
//     in this file lowercases it before storing.
//   - VulnIDs holds the IDs of the vulnIDNode children; addVulnID appends to
//     it and Neighbors returns it.
type vulnTypeStruct struct {
	ThisID  string
	Type    string
	VulnIDs []string
}
// vulnIDNode is the stored node for one vulnerability ID under a type.
//   - ThisID is the node's own ID, as returned by ID().
//   - Parent is the ThisID of the owning vulnTypeStruct.
//   - VulnID is the vulnerability identifier; the ingest path in this file
//     lowercases it before storing. Parent and VulnID together form Key().
//   - CertifyVulnLinks, VulnEqualLinks, VexLinks and VulnMetadataLinks are
//     back-edge ID lists, appended to by the set*Links methods and exposed
//     through Neighbors.
type vulnIDNode struct {
	ThisID            string
	Parent            string
	VulnID            string
	CertifyVulnLinks  []string
	VulnEqualLinks    []string
	VexLinks          []string
	VulnMetadataLinks []string
}

// ID returns the ID of this vulnerability type node.
func (n *vulnTypeStruct) ID() string {
	return n.ThisID
}

// ID returns the ID of this vulnerability ID node.
func (n *vulnIDNode) ID() string {
	return n.ThisID
}

// Key returns the lookup key of a vulnerability type node: the hash of its
// Type field.
func (n *vulnTypeStruct) Key() string {
	typeName := n.Type
	return hashKey(typeName)
}

// Key returns the lookup key of a vulnerability ID node: the hash of
// "<Parent>:<VulnID>".
func (n *vulnIDNode) Key() string {
	composite := n.Parent + ":" + n.VulnID
	return hashKey(composite)
}

// Neighbors returns the IDs of the child vulnerability ID nodes when the
// type-to-ID edge is allowed, and nil otherwise.
func (n *vulnTypeStruct) Neighbors(allowedEdges edgeMap) []string {
	if !allowedEdges[model.EdgeVulnerabilityTypeVulnerabilityID] {
		return nil
	}
	return n.VulnIDs
}

// Neighbors collects the IDs reachable from this vulnerability ID node over
// the edges enabled in allowedEdges: the parent type node first, then the
// certifyVuln, vulnEqual, VEX and vulnerability metadata back edges, in that
// order. The result is nil when nothing is collected.
func (n *vulnIDNode) Neighbors(allowedEdges edgeMap) []string {
	candidates := []struct {
		allowed bool
		ids     []string
	}{
		{allowedEdges[model.EdgeVulnerabilityIDVulnerabilityType], []string{n.Parent}},
		{allowedEdges[model.EdgeVulnerabilityCertifyVuln], n.CertifyVulnLinks},
		{allowedEdges[model.EdgeVulnerabilityVulnEqual], n.VulnEqualLinks},
		{allowedEdges[model.EdgeVulnerabilityCertifyVexStatement], n.VexLinks},
		{allowedEdges[model.EdgeVulnMetadataVulnerability], n.VulnMetadataLinks},
	}

	var neighbors []string
	for _, group := range candidates {
		if group.allowed {
			neighbors = append(neighbors, group.ids...)
		}
	}
	return neighbors
}

// BuildModelNode builds the GraphQL node for this vulnerability type by
// calling buildVulnResponse with the node's own ID and no filter.
//
// On error, or when no vulnerability is produced, it returns an untyped nil
// model.Node. Converting a nil *model.Vulnerability directly would yield a
// non-nil interface wrapping a nil pointer.
func (n *vulnTypeStruct) BuildModelNode(ctx context.Context, c *demoClient) (model.Node, error) {
	v, err := c.buildVulnResponse(ctx, n.ThisID, nil)
	if err != nil || v == nil {
		return nil, err
	}
	return v, nil
}
// BuildModelNode builds the GraphQL node for this vulnerability ID by calling
// buildVulnResponse with the node's own ID and no filter.
//
// On error, or when no vulnerability is produced, it returns an untyped nil
// model.Node. Converting a nil *model.Vulnerability directly would yield a
// non-nil interface wrapping a nil pointer.
func (n *vulnIDNode) BuildModelNode(ctx context.Context, c *demoClient) (model.Node, error) {
	v, err := c.buildVulnResponse(ctx, n.ThisID, nil)
	if err != nil || v == nil {
		return nil, err
	}
	return v, nil
}

// setVulnerabilityLinks appends id to the node's certifyVulnerability back
// edges (CertifyVulnLinks) and writes the node to vulnIDCol via setkv,
// returning setkv's error.
func (n *vulnIDNode) setVulnerabilityLinks(ctx context.Context, id string, c *demoClient) error {
	n.CertifyVulnLinks = append(n.CertifyVulnLinks, id)

	err := setkv(ctx, vulnIDCol, n, c)
	return err
}

// setVulnEqualLinks appends id to the node's equalVulnerability back edges
// (VulnEqualLinks) and writes the node to vulnIDCol via setkv, returning
// setkv's error.
func (n *vulnIDNode) setVulnEqualLinks(ctx context.Context, id string, c *demoClient) error {
	n.VulnEqualLinks = append(n.VulnEqualLinks, id)

	err := setkv(ctx, vulnIDCol, n, c)
	return err
}

// setVexLinks appends id to the node's certifyVexStatement back edges
// (VexLinks) and writes the node to vulnIDCol via setkv, returning setkv's
// error.
func (n *vulnIDNode) setVexLinks(ctx context.Context, id string, c *demoClient) error {
	n.VexLinks = append(n.VexLinks, id)

	err := setkv(ctx, vulnIDCol, n, c)
	return err
}

// setVulnMetadataLinks appends id to the node's vulnerability metadata back
// edges (VulnMetadataLinks) and writes the node to vulnIDCol via setkv,
// returning setkv's error.
func (n *vulnIDNode) setVulnMetadataLinks(ctx context.Context, id string, c *demoClient) error {
	n.VulnMetadataLinks = append(n.VulnMetadataLinks, id)

	err := setkv(ctx, vulnIDCol, n, c)
	return err
}

// addVulnID appends vulnID to the type node's list of child vulnerability IDs
// and writes the node to vulnTypeCol via setkv, returning setkv's error.
func (n *vulnTypeStruct) addVulnID(ctx context.Context, vulnID string, c *demoClient) error {
	n.VulnIDs = append(n.VulnIDs, vulnID)

	err := setkv(ctx, vulnTypeCol, n, c)
	return err
}

// Ingest Vulnerabilities

// IngestVulnerabilities ingests each entry of vulns in order through
// IngestVulnerability and returns the resulting IDs in the same order.
//
// The first failure aborts the batch: the IDs collected so far are discarded
// and the error is returned wrapped in a gqlerror.
func (c *demoClient) IngestVulnerabilities(ctx context.Context, vulns []*model.IDorVulnerabilityInput) ([]*model.VulnerabilityIDs, error) {
	var ingested []*model.VulnerabilityIDs
	for i := range vulns {
		ids, err := c.IngestVulnerability(ctx, *vulns[i])
		if err != nil {
			return nil, gqlerror.Errorf("IngestVulnerability failed with err: %v", err)
		}
		ingested = append(ingested, ids)
	}
	return ingested, nil
}

// IngestVulnerability makes sure that both the vulnerability type node and the
// vulnerability ID node described by input.VulnerabilityInput exist, creating
// whichever is missing, and returns their IDs. Type and vulnerability ID are
// lowercased before lookup and storage.
//
// Each node is first looked up under the read lock. Only when it is not found
// is the write lock taken, the lookup repeated, and the node created if it is
// still missing.
//
// It returns an error when input.VulnerabilityInput is nil, or when any lookup
// or write fails with something other than kv.NotFoundError.
func (c *demoClient) IngestVulnerability(ctx context.Context, input model.IDorVulnerabilityInput) (*model.VulnerabilityIDs, error) {
	if input.VulnerabilityInput == nil {
		return nil, gqlerror.Errorf("IngestVulnerability failed: vulnerability input must be specified")
	}

	inType := &vulnTypeStruct{
		Type: strings.ToLower(input.VulnerabilityInput.Type),
	}
	typeKey := inType.Key()

	c.m.RLock()
	outType, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, typeKey, c)
	c.m.RUnlock()
	if err != nil {
		if !errors.Is(err, kv.NotFoundError) {
			return nil, err
		}
		outType, err = func() (*vulnTypeStruct, error) {
			c.m.Lock()
			defer c.m.Unlock()

			// Another ingest may have created the type between the read
			// lock being released and the write lock being acquired.
			found, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, typeKey, c)
			if err == nil {
				return found, nil
			}
			if !errors.Is(err, kv.NotFoundError) {
				return nil, err
			}

			inType.ThisID = c.getNextID()
			if err := c.addToIndex(ctx, vulnTypeCol, inType); err != nil {
				return nil, err
			}
			if err := setkv(ctx, vulnTypeCol, inType, c); err != nil {
				return nil, err
			}
			return inType, nil
		}()
		if err != nil {
			return nil, err
		}
	}

	inVulnID := &vulnIDNode{
		Parent: outType.ThisID,
		VulnID: strings.ToLower(input.VulnerabilityInput.VulnerabilityID),
	}
	vulnKey := inVulnID.Key()

	c.m.RLock()
	outVulnID, err := byKeykv[*vulnIDNode](ctx, vulnIDCol, vulnKey, c)
	c.m.RUnlock()
	if err != nil {
		if !errors.Is(err, kv.NotFoundError) {
			return nil, err
		}
		outVulnID, err = func() (*vulnIDNode, error) {
			c.m.Lock()
			defer c.m.Unlock()

			found, err := byKeykv[*vulnIDNode](ctx, vulnIDCol, vulnKey, c)
			if err == nil {
				return found, nil
			}
			if !errors.Is(err, kv.NotFoundError) {
				return nil, err
			}

			// outType was read before this write lock was taken, so another
			// ingest may have appended IDs to the stored type node since.
			// addVulnID rewrites the whole node, so reload it here; otherwise
			// those IDs would be overwritten with the stale list.
			parent, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, typeKey, c)
			if err != nil {
				return nil, err
			}

			inVulnID.ThisID = c.getNextID()
			if err := c.addToIndex(ctx, vulnIDCol, inVulnID); err != nil {
				return nil, err
			}
			if err := setkv(ctx, vulnIDCol, inVulnID, c); err != nil {
				return nil, err
			}
			if err := parent.addVulnID(ctx, inVulnID.ThisID, c); err != nil {
				return nil, err
			}
			return inVulnID, nil
		}()
		if err != nil {
			return nil, err
		}
	}

	return &model.VulnerabilityIDs{
		VulnerabilityTypeID: outType.ThisID,
		VulnerabilityNodeID: outVulnID.ThisID,
	}, nil
}

// Query Vulnerabilities

// VulnerabilityList returns one page of the vulnerabilities matching vulnSpec
// as a connection. Each edge carries a model.Vulnerability holding exactly one
// vulnerability ID, and the edge cursor is that vulnerability ID node's ID.
//
// Parameters:
//   - vulnSpec: the filter. When ID is set, only that node is looked up. When
//     NoVuln is true, the filter is rewritten to type "novuln" with an empty
//     vulnerability ID. Setting NoVuln to false together with Type "novuln"
//     is rejected with an error.
//   - after: optional cursor; only entries past it are returned.
//   - first: optional maximum number of edges to return.
//
// It returns (nil, nil) when nothing matches. PageInfo.StartCursor and
// PageInfo.EndCursor are the cursors of the first and last returned edges, so
// EndCursor can be passed back as after.
func (c *demoClient) VulnerabilityList(ctx context.Context, vulnSpec model.VulnerabilitySpec, after *string, first *int) (*model.VulnerabilityConnection, error) {
	c.m.RLock()
	defer c.m.RUnlock()

	if vulnSpec.ID != nil {
		v, err := c.buildVulnResponse(ctx, *vulnSpec.ID, &vulnSpec)
		if err != nil {
			if errors.Is(err, errNotFound) {
				// not found
				return nil, nil
			}
			return nil, err
		}
		if v == nil {
			// buildVulnResponse reports a filter mismatch as (nil, nil);
			// treat it as "no results" instead of dereferencing nil below.
			return nil, nil
		}

		return &model.VulnerabilityConnection{
			TotalCount: 1,
			PageInfo: &model.PageInfo{
				HasNextPage: false,
				StartCursor: ptrfrom.String(v.ID),
				EndCursor:   ptrfrom.String(v.ID),
			},
			Edges: []*model.VulnerabilityEdge{
				{
					Cursor: v.ID,
					Node:   v,
				},
			},
		}, nil
	}

	edges := make([]*model.VulnerabilityEdge, 0)
	hasNextPage := false
	numNodes := 0
	totalCount := 0
	addToCount := 0

	// appendEdge adds one edge while honoring the optional page size: with no
	// limit every edge is kept; with a limit, edges are kept until the page
	// is full and the first overflow marks that another page exists.
	appendEdge := func(cursor string, node *model.Vulnerability) {
		switch {
		case first == nil:
			edges = append(edges, &model.VulnerabilityEdge{Cursor: cursor, Node: node})
		case numNodes < *first:
			edges = append(edges, &model.VulnerabilityEdge{Cursor: cursor, Node: node})
			numNodes++
		case numNodes == *first:
			hasNextPage = true
		}
	}

	if vulnSpec.NoVuln != nil && !*vulnSpec.NoVuln {
		if vulnSpec.Type != nil && *vulnSpec.Type == noVulnType {
			return nil, gqlerror.Errorf("novuln boolean set to false, cannot specify vulnerability type to be novuln")
		}
	}

	// if novuln is specified, retrieve all "novuln" type nodes
	if vulnSpec.NoVuln != nil && *vulnSpec.NoVuln {
		vulnSpec.Type = ptrfrom.String(noVulnType)
		vulnSpec.VulnerabilityID = ptrfrom.String("")
	}

	if vulnSpec.Type != nil {
		inType := &vulnTypeStruct{
			Type: strings.ToLower(*vulnSpec.Type),
		}
		typeStruct, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, inType.Key(), c)
		if err == nil {
			for _, id := range c.buildVulnID(ctx, typeStruct, &vulnSpec) {
				// Compare each entry's own ID with the cursor, and use it as
				// the edge cursor, so that every edge is individually
				// addressable.
				if after != nil && id.ID <= *after {
					continue
				}
				addToCount++
				appendEdge(id.ID, &model.Vulnerability{
					ID:   typeStruct.ThisID,
					Type: typeStruct.Type,
					VulnerabilityIDs: []*model.VulnerabilityID{
						id,
					},
				})
			}
		}
	} else {
		// Without a cursor the page starts at the very first entry; with one,
		// entries are skipped until the cursor has been seen.
		currentPage := after == nil

		var done bool
		scn := c.kv.Keys(vulnTypeCol)
		for !done {
			var typeKeys []string
			var err error
			typeKeys, done, err = scn.Scan(ctx)
			if err != nil {
				return nil, err
			}

			sort.Strings(typeKeys)
			// NOTE(review): totalCount is derived from the number of type
			// keys in the latest scan batch, not from the number of
			// vulnerability IDs; kept as-is, confirm the intended meaning.
			totalCount = len(typeKeys)

			for i, tk := range typeKeys {
				typeStruct, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, tk, c)
				if err != nil {
					return nil, err
				}
				for _, id := range c.buildVulnID(ctx, typeStruct, &vulnSpec) {
					if after != nil && !currentPage {
						if id.ID == *after {
							totalCount = len(typeKeys) - (i + 1)
							currentPage = true
						}
						continue
					}

					appendEdge(id.ID, &model.Vulnerability{
						ID:   typeStruct.ThisID,
						Type: typeStruct.Type,
						VulnerabilityIDs: []*model.VulnerabilityID{
							id,
						},
					})
				}
			}
		}
	}

	if len(edges) == 0 {
		return nil, nil
	}

	// Report the cursors of the first and last returned edges. The node ID
	// is the type's ID here, which is not what `after` is compared against.
	return &model.VulnerabilityConnection{
		TotalCount: totalCount + addToCount,
		PageInfo: &model.PageInfo{
			HasNextPage: hasNextPage,
			StartCursor: ptrfrom.String(edges[0].Cursor),
			EndCursor:   ptrfrom.String(edges[len(edges)-1].Cursor),
		},
		Edges: edges,
	}, nil
}

// Vulnerabilities returns the vulnerabilities matching filter, grouped by
// type: each returned model.Vulnerability carries every matching
// vulnerability ID of one type.
//
// A nil filter matches everything. When filter.ID is set, only that node is
// looked up. When filter.NoVuln is true, the lookup is rewritten to type
// "novuln" with an empty vulnerability ID. Setting NoVuln to false together
// with Type "novuln" is rejected with an error. The caller's filter is never
// modified.
//
// A lookup by ID that finds nothing, or whose node does not satisfy the rest
// of the filter, returns (nil, nil).
func (c *demoClient) Vulnerabilities(ctx context.Context, filter *model.VulnerabilitySpec) ([]*model.Vulnerability, error) {
	c.m.RLock()
	defer c.m.RUnlock()

	if filter != nil {
		// Work on a shallow copy so the novuln rewrite below does not leak
		// into the caller's spec.
		spec := *filter
		filter = &spec
	}

	if filter != nil && filter.ID != nil {
		v, err := c.buildVulnResponse(ctx, *filter.ID, filter)
		if err != nil {
			if errors.Is(err, errNotFound) {
				// not found
				return nil, nil
			}
			return nil, err
		}
		if v == nil {
			// buildVulnResponse reports a filter mismatch as (nil, nil);
			// do not hand a slice containing a nil element to the caller.
			return nil, nil
		}
		return []*model.Vulnerability{v}, nil
	}

	if filter != nil && filter.NoVuln != nil && !*filter.NoVuln {
		if filter.Type != nil && *filter.Type == noVulnType {
			return []*model.Vulnerability{}, gqlerror.Errorf("novuln boolean set to false, cannot specify vulnerability type to be novuln")
		}
	}

	out := []*model.Vulnerability{}

	// collect appends one result for typeStruct when at least one of its
	// vulnerability IDs matches the filter.
	collect := func(typeStruct *vulnTypeStruct) {
		vulnIDs := c.buildVulnID(ctx, typeStruct, filter)
		if len(vulnIDs) > 0 {
			out = append(out, &model.Vulnerability{
				ID:               typeStruct.ThisID,
				Type:             typeStruct.Type,
				VulnerabilityIDs: vulnIDs,
			})
		}
	}

	// if novuln is specified, retrieve all "novuln" type nodes
	if filter != nil && filter.NoVuln != nil && *filter.NoVuln {
		filter.Type = ptrfrom.String(noVulnType)
		filter.VulnerabilityID = ptrfrom.String("")
	}

	if filter != nil && filter.Type != nil {
		inType := &vulnTypeStruct{
			Type: strings.ToLower(*filter.Type),
		}
		typeStruct, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, inType.Key(), c)
		if err == nil {
			collect(typeStruct)
		}
		return out, nil
	}

	var done bool
	scn := c.kv.Keys(vulnTypeCol)
	for !done {
		var typeKeys []string
		var err error
		typeKeys, done, err = scn.Scan(ctx)
		if err != nil {
			return nil, err
		}
		for _, tk := range typeKeys {
			typeStruct, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, tk, c)
			if err != nil {
				return nil, err
			}
			collect(typeStruct)
		}
	}
	return out, nil
}

// buildVulnID returns the vulnerability IDs stored under typeStruct that
// satisfy filter.
//
// When filter names a vulnerability ID, that single node is looked up by key
// and returned on its own. Otherwise every ID listed on the type node is
// loaded and kept unless the filter excludes it. Any lookup failure yields
// nil.
func (c *demoClient) buildVulnID(ctx context.Context, typeStruct *vulnTypeStruct, filter *model.VulnerabilitySpec) []*model.VulnerabilityID {
	toModel := func(node *vulnIDNode) *model.VulnerabilityID {
		return &model.VulnerabilityID{
			ID:              node.ThisID,
			VulnerabilityID: node.VulnID,
		}
	}

	if filter != nil && filter.VulnerabilityID != nil {
		wanted := &vulnIDNode{
			Parent: typeStruct.ThisID,
			VulnID: strings.ToLower(*filter.VulnerabilityID),
		}
		node, err := byKeykv[*vulnIDNode](ctx, vulnIDCol, wanted.Key(), c)
		if err != nil {
			return nil
		}
		return []*model.VulnerabilityID{toModel(node)}
	}

	matched := []*model.VulnerabilityID{}
	for _, nodeID := range typeStruct.VulnIDs {
		node, err := byIDkv[*vulnIDNode](ctx, nodeID, c)
		if err != nil {
			return nil
		}
		if filter == nil || !noMatch(toLower(filter.VulnerabilityID), node.VulnID) {
			matched = append(matched, toModel(node))
		}
	}
	return matched
}

// exactVulnerability resolves filter to a single vulnerability ID node when
// the filter pins one down, either through ID or through Type together with
// VulnerabilityID.
//
// It returns (nil, nil) when filter is nil, when it does not identify a single
// node, or when a lookup fails with kv.NotFoundError or errTypeNotMatch. Any
// other lookup error is returned.
func (c *demoClient) exactVulnerability(ctx context.Context, filter *model.VulnerabilitySpec) (*vulnIDNode, error) {
	if filter == nil {
		return nil, nil
	}

	// missing reports whether err only means that no such node exists.
	missing := func(err error) bool {
		return errors.Is(err, kv.NotFoundError) || errors.Is(err, errTypeNotMatch)
	}

	if filter.ID != nil {
		node, err := byIDkv[*vulnIDNode](ctx, *filter.ID, c)
		switch {
		case err == nil:
			return node, nil
		case missing(err):
			return nil, nil
		default:
			return nil, err
		}
	}

	if filter.Type == nil || filter.VulnerabilityID == nil {
		return nil, nil
	}

	typeProbe := &vulnTypeStruct{
		Type: strings.ToLower(*filter.Type),
	}
	parent, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, typeProbe.Key(), c)
	if err != nil {
		if missing(err) {
			return nil, nil
		}
		return nil, err
	}

	idProbe := &vulnIDNode{
		Parent: parent.ThisID,
		VulnID: strings.ToLower(*filter.VulnerabilityID),
	}
	node, err := byKeykv[*vulnIDNode](ctx, vulnIDCol, idProbe.Key(), c)
	if err != nil {
		if missing(err) {
			return nil, nil
		}
		return nil, err
	}
	return node, nil
}

// buildVulnResponse assembles the model.Vulnerability sent as GraphQL response
// for the node identified by id, which may be either a vulnerability ID node
// or a vulnerability type node.
//
// The optional filter restricts output on selection operations: when the node
// does not satisfy it, the result is (nil, nil). When id resolves to neither
// node kind, the returned error wraps errNotFound.
func (c *demoClient) buildVulnResponse(ctx context.Context, id string, filter *model.VulnerabilitySpec) (*model.Vulnerability, error) {
	if filter != nil && filter.ID != nil && *filter.ID != id {
		return nil, nil
	}

	// typeID starts out as id and is redirected to the parent when id turns
	// out to name a vulnerability ID node.
	typeID := id
	var ids []*model.VulnerabilityID

	vulnNode, err := byIDkv[*vulnIDNode](ctx, id, c)
	switch {
	case err == nil:
		if filter != nil && noMatch(toLower(filter.VulnerabilityID), vulnNode.VulnID) {
			return nil, nil
		}
		// IDs are generated as string even though we ask for integers
		// See https://github.com/99designs/gqlgen/issues/2561
		ids = append(ids, &model.VulnerabilityID{
			ID:              vulnNode.ThisID,
			VulnerabilityID: vulnNode.VulnID,
		})
		typeID = vulnNode.Parent
	case !errors.Is(err, kv.NotFoundError) && !errors.Is(err, errTypeNotMatch):
		return nil, fmt.Errorf("Error retrieving node for id: %v : %w", id, err)
	}

	typeStruct, err := byIDkv[*vulnTypeStruct](ctx, typeID, c)
	if err != nil {
		if errors.Is(err, kv.NotFoundError) || errors.Is(err, errTypeNotMatch) {
			return nil, fmt.Errorf("%w: ID does not match expected node type for vulnerability", errNotFound)
		}
		return nil, fmt.Errorf("Error retrieving node for id: %v : %w", typeID, err)
	}
	if filter != nil && noMatch(toLower(filter.Type), typeStruct.Type) {
		return nil, nil
	}

	return &model.Vulnerability{
		ID:               typeStruct.ThisID,
		Type:             typeStruct.Type,
		VulnerabilityIDs: ids,
	}, nil
}

// getVulnerabilityFromInput looks up the vulnerability ID node described by
// input: first the type node by its lowercased type name, then the ID node
// under that type by its lowercased vulnerability ID. Lookup errors are
// returned unchanged.
func (c *demoClient) getVulnerabilityFromInput(ctx context.Context, input model.VulnerabilityInputSpec) (*vulnIDNode, error) {
	typeProbe := &vulnTypeStruct{Type: strings.ToLower(input.Type)}
	parent, err := byKeykv[*vulnTypeStruct](ctx, vulnTypeCol, typeProbe.Key(), c)
	if err != nil {
		return nil, err
	}

	idProbe := &vulnIDNode{
		Parent: parent.ThisID,
		VulnID: strings.ToLower(input.VulnerabilityID),
	}
	node, err := byKeykv[*vulnIDNode](ctx, vulnIDCol, idProbe.Key(), c)
	if err != nil {
		return nil, err
	}
	return node, nil
}

// returnFoundVulnerability returns the vulnerability ID node referenced by
// vulnIDorInput. When VulnerabilityNodeID is set, the node is looked up by
// that ID; otherwise the lookup falls back to VulnerabilityInput.
//
// It returns a gqlerror when vulnIDorInput is nil, when neither the node ID
// nor the input spec is set, or when the lookup fails.
func (c *demoClient) returnFoundVulnerability(ctx context.Context, vulnIDorInput *model.IDorVulnerabilityInput) (*vulnIDNode, error) {
	if vulnIDorInput == nil {
		return nil, gqlerror.Errorf("failed to find vulnerability: no vulnerability ID or input specified")
	}

	if vulnIDorInput.VulnerabilityNodeID != nil {
		foundVulnID, err := byIDkv[*vulnIDNode](ctx, *vulnIDorInput.VulnerabilityNodeID, c)
		if err != nil {
			return nil, gqlerror.Errorf("failed to return vulnIDNode node by ID with error: %v", err)
		}
		return foundVulnID, nil
	}

	// Guard the fallback: dereferencing a nil input spec would panic.
	if vulnIDorInput.VulnerabilityInput == nil {
		return nil, gqlerror.Errorf("failed to find vulnerability: neither vulnerability node ID nor vulnerability input specified")
	}

	foundVulnID, err := c.getVulnerabilityFromInput(ctx, *vulnIDorInput.VulnerabilityInput)
	if err != nil {
		return nil, gqlerror.Errorf("failed to getVulnerabilityFromInput with error: %v", err)
	}
	return foundVulnID, nil
}
